{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import HfArgumentParser\n",
    "import torch\n",
    "import transformers\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "from dataclasses import dataclass, field"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel, TrainingArguments, AutoConfig\n",
    "from modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from peft import get_peft_model, LoraConfig, TaskType\n",
    "\n",
    "\n",
    "torch.set_default_tensor_type(torch.cuda.HalfTensor)\n",
    "model = ChatGLMForConditionalGeneration.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True, device_map='auto')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "peft_path = \"output/chatglm-lora.pt\"\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM, inference_mode=False,\n",
    "    r=8,\n",
    "    lora_alpha=32, lora_dropout=0.1\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, peft_config)\n",
    "model.load_state_dict(torch.load(peft_path), strict=False)\n",
    "torch.set_default_tensor_type(torch.cuda.FloatTensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"THUDM/chatglm-6b\", trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "\n",
    "instructions = json.load(open(\"data/alpaca_data.json\"))\n",
    "\n",
    "instructions = [\n",
    "    {\n",
    "        'instruction': \"下面哪个产品与其他不同\",\n",
    "        \"input\": \"知乎, 百度, 微博, 拼多多\",\n",
    "        \"output\": \"拼多多与其他产品不同，因为它是一个电子商务平台，而知乎、百度和微博都是社交平台/搜索引擎。\",\n",
    "    },\n",
    "    {\n",
    "        'instruction': \"下面哪个产品与其他不同\",\n",
    "        \"input\": \"chatgpt, 文心, CPM, 拼多多\",\n",
    "        \"output\": \"拼多多与其他产品不同，因为它是一家电子商务公司，而ChatGPT、文心和CPM都不是电子商务公司。\",\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "answers = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for idx, item in enumerate(instructions[:5]):\n",
    "        input_text = f\"### {idx+1}.Instruction:\\n{item['instruction']}\\n\\n\"\n",
    "        if item.get('input'):\n",
    "            input_text += f\"### {idx+1}.Input:\\n{item['input']}\\n\\n\"\n",
    "        input_text += f\"### {idx+1}.Response:\"\n",
    "        batch = tokenizer(input_text, return_tensors=\"pt\")\n",
    "        out = model.generate(\n",
    "            input_ids=batch[\"input_ids\"],\n",
    "            attention_mask=torch.ones_like(batch[\"input_ids\"]).bool(),\n",
    "            max_length=512,\n",
    "            temperature=0\n",
    "        )\n",
    "        out_text = tokenizer.decode(out[0])\n",
    "        answer = out_text.replace(input_text, \"\").replace(\"\\nEND\", \"\").strip()\n",
    "        item['infer_answer'] = answer\n",
    "        print(out_text)\n",
    "        print(f\"### {idx+1}.Answer:\\n\", item.get('output'), '\\n\\n')\n",
    "        answers.append({'index': idx, **item})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10 (default, Nov 14 2022, 12:59:47) \n[GCC 9.4.0]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "25273a2a68c96ebac13d7fb9e0db516f9be0772777a0507fe06d682a441a3ba7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
